//! tcp 远程登录字符设备
//!
//! 实现 telnet 远程网络登录字符设备
//!
//! 注意:
//!
//! 目前由于`CharDevice`的轮询实现, 很难办到多路复用, 因此针对每一个端口的监听,
//! 只能对应一个线程, 无法在一个轮询中处理多路事件, 要是能使用异步编程就完美了.

use std::{
    cell::UnsafeCell,
    collections::{HashMap, VecDeque},
    io::{Read, Write},
    net::SocketAddr,
    sync::{
        atomic::{AtomicBool, AtomicU32, Ordering},
        Arc, LazyLock, Mutex,
    },
};

use manage::SiminkManage;
use mio::{
    net::{TcpListener, TcpStream},
    Events, Interest, Poll, Token,
};
use monitor::Monitor;

use crate::{CharBackendMap, CharDeviceImpl, PollEvent, PollStatus};

/// Telnet option negotiation sent to every client right after its connection is accepted.
///
/// Each row is one three-byte `IAC <command> <option>` sequence
/// (`0xfb` = WILL, `0xfd` = DO).
static TELNET_PROTOCOL: [u8; 12] = [
    // IAC WILL ECHO (option 1, RFC 857)
    0xff, 0xfb, 0x01,
    // IAC WILL SUPPRESS-GO-AHEAD (option 3, RFC 858)
    0xff, 0xfb, 0x03,
    // IAC WILL BINARY (option 0, RFC 856)
    0xff, 0xfb, 0x00,
    // IAC DO BINARY (option 0, RFC 856)
    0xff, 0xfd, 0x00,
];

/// Telnet command byte: End of Record (RFC 885).
#[allow(unused)]
const IAC_EOR: u8 = 0xEF;
/// Telnet command byte: end of subnegotiation.
#[allow(unused)]
const IAC_SE: u8 = 0xF0;
/// Telnet command byte: no operation.
#[allow(unused)]
const IAC_NOP: u8 = 0xF1;
/// Telnet command byte: break.
const IAC_BREAK: u8 = 0xF3;
/// Telnet command byte: interrupt process.
#[allow(unused)]
const IAC_IP: u8 = 0xF4;
/// Telnet command byte: start of subnegotiation.
#[allow(unused)]
const IAC_SB: u8 = 0xFA;
/// "Interpret As Command" escape byte that introduces every telnet command.
const IAC: u8 = 0xFF;

/// Telnet character device.
///
/// Listens on one TCP address and keeps at most one client connection in `stream`.
/// `wait` accepts a client, `poll` reads from it and strips telnet commands, and the
/// remaining data bytes are buffered in `bufcache` until `read` hands them out.
pub struct CharDeviceTelnet {
    /// Starts as the listening address; overwritten with the peer address of each
    /// accepted client.
    connect_addr: UnsafeCell<SocketAddr>,
    /// Listening socket. It is bound in `create` and registered with `listener_poll`
    /// under `Token(0)` in `prepare`.
    listener: UnsafeCell<TcpListener>,
    /// Event buffer filled by `listener_poll`.
    listener_events: UnsafeCell<Events>,
    /// Poll instance used by `wait` to detect incoming connections.
    listener_poll: UnsafeCell<Poll>,
    /// Currently connected client, if any.
    stream: Mutex<Option<TcpStream>>,
    /// Event buffer filled by `stream_poll`.
    stream_events: UnsafeCell<Events>,
    /// Poll instance with which each accepted client is registered under `Token(1)`.
    stream_poll: UnsafeCell<Poll>,
    /// State of the telnet command parser (`tcp_chr_process_IAC_bytes`); starts at 1.
    do_telnetopt: AtomicU32,
    /// Data bytes received from the client, with telnet commands removed, waiting to be
    /// returned by `read`.
    bufcache: UnsafeCell<VecDeque<u8>>,
    /// Set by `notify_accept_input`; cleared by `poll` when it reports buffered data as
    /// ready.
    accept_input: AtomicBool,
}

// SAFETY: NOTE(review): the `UnsafeCell` fields are read and written without any
// synchronization. These impls are therefore only sound if at most one thread at a time
// drives the methods that touch them (`prepare`, `wait`, `poll`, `can_read`, `read`).
// Confirm this against the `CharDevice` driver. Only `stream` (behind a `Mutex`) and the
// atomics are safe to access concurrently.
unsafe impl Send for CharDeviceTelnet {}
unsafe impl Sync for CharDeviceTelnet {}

impl CharDeviceTelnet {
    /// 创建一个 telnet 字符设备
    #[allow(clippy::missing_panics_doc)]
    pub fn create(socket_addr: SocketAddr) -> Self {
        Self {
            connect_addr: UnsafeCell::new(socket_addr),
            listener: UnsafeCell::new(TcpListener::bind(socket_addr).unwrap()),
            listener_events: UnsafeCell::new(Events::with_capacity(10)),
            listener_poll: UnsafeCell::new(Poll::new().unwrap()),
            stream: Mutex::new(None),
            stream_events: UnsafeCell::new(Events::with_capacity(10)),
            stream_poll: UnsafeCell::new(Poll::new().unwrap()),
            do_telnetopt: AtomicU32::new(1),
            bufcache: UnsafeCell::new(VecDeque::with_capacity(100)),
            accept_input: AtomicBool::new(false),
        }
    }

    #[allow(non_snake_case)]
    fn tcp_chr_process_IAC_bytes(&self, buf: &mut [u8]) -> (usize, Vec<PollEvent>) {
        let mut j = 0;
        let mut v = vec![];
        for i in 0..buf.len() {
            if self.do_telnetopt.load(Ordering::Relaxed) > 1 {
                if buf[i] == IAC && self.do_telnetopt.load(Ordering::Relaxed) == 2 {
                    // Double IAC means send an IAC
                    if j != i {
                        buf[j] = buf[i];
                    }
                    j += 1;
                    self.do_telnetopt.store(1, Ordering::Relaxed);
                } else {
                    if buf[i] == IAC_BREAK && self.do_telnetopt.load(Ordering::Relaxed) == 2 {
                        // Handle IAC break commands by sending a serial break
                        v.push(PollEvent::Break);
                        self.do_telnetopt.fetch_add(1, Ordering::Relaxed);
                    }
                    self.do_telnetopt.fetch_add(1, Ordering::Relaxed);
                }

                if self.do_telnetopt.load(Ordering::Relaxed) >= 4 {
                    self.do_telnetopt.store(1, Ordering::Relaxed);
                }
            } else if buf[i] == IAC {
                self.do_telnetopt.store(2, Ordering::Relaxed);
            } else {
                if j != i {
                    buf[j] = buf[i];
                }
                j += 1;
            }
        }
        (j, v)
    }
}

impl CharDeviceTelnet {
    /// Accepts one queued client connection, if there is one.
    ///
    /// The accepted client is handled as follows:
    /// - its address is recorded in `connect_addr`;
    /// - it is sent the initial telnet negotiation (`TELNET_PROTOCOL`);
    /// - it is registered with `stream_poll` under `Token(1)`;
    /// - it is stored as the active stream, replacing any previous one.
    ///
    /// Returns `Ok(false)` if no connection is waiting.
    fn accept_pending(&self) -> std::io::Result<bool> {
        // SAFETY: NOTE(review): the `UnsafeCell` fields are assumed to be accessed only by
        // the single thread driving this device; the `unsafe impl Sync` relies on this.
        let listener = unsafe { &*self.listener.get() };
        let (mut connect, connect_addr) = loop {
            match listener.accept() {
                Ok(accepted) => break accepted,
                Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => return Ok(false),
                Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(err) => return Err(err),
            }
        };

        // SAFETY: see above.
        unsafe { *self.connect_addr.get() = connect_addr };

        // `write_all` retries short writes and turns a zero-length write into an error
        // instead of looping forever.
        connect.write_all(&TELNET_PROTOCOL)?;

        // SAFETY: see above.
        let stream_poll = unsafe { &*self.stream_poll.get() };
        stream_poll.registry().register(&mut connect, Token(1), Interest::READABLE)?;

        *self.stream.lock().unwrap() = Some(connect);
        Ok(true)
    }
}

impl CharDeviceImpl for CharDeviceTelnet {
    /// Registers the listening socket with `listener_poll` under `Token(0)`.
    fn prepare(&self) -> std::io::Result<()> {
        // SAFETY: see `accept_pending`.
        let listener = unsafe { &mut *self.listener.get() };
        let poll = unsafe { &mut *self.listener_poll.get() };
        poll.registry().register(listener, Token(0), Interest::READABLE)
    }

    /// Waits up to `timeout` for a client and accepts it (see `accept_pending`).
    ///
    /// Returns `Ok(true)` once a connection has been accepted. Returns `Ok(false)` if
    /// none arrived in time or the wait was interrupted by a signal.
    fn wait(&self, timeout: std::time::Duration) -> std::io::Result<bool> {
        // mio readiness is edge-triggered, and at most one connection is accepted per
        // call. A connection that was queued together with an earlier one is therefore
        // not reported again, so take any queued connection before blocking.
        if self.accept_pending()? {
            return Ok(true);
        }

        // SAFETY: see `accept_pending`.
        let events = unsafe { &mut *self.listener_events.get() };
        let poll = unsafe { &mut *self.listener_poll.get() };
        match poll.poll(events, Some(timeout)) {
            Ok(()) => {},
            Err(err) if err.kind() == std::io::ErrorKind::Interrupted => return Ok(false),
            Err(err) => return Err(err),
        }

        for event in events.iter() {
            if event.token() == Token(0) && event.is_readable() && self.accept_pending()? {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Waits up to `timeout` for input from the connected client.
    ///
    /// If decoded bytes are still buffered and `notify_accept_input` has been called
    /// since the last report, returns `Ready` immediately without polling. Otherwise,
    /// once the client becomes readable, the socket is drained until it would block.
    /// Each chunk is passed through `tcp_chr_process_IAC_bytes`, and the data bytes are
    /// appended to the internal buffer.
    ///
    /// Returns:
    /// - `Event(events)` if telnet commands produced events;
    /// - `Ready` if bytes were received;
    /// - `ConnectAborted` if the peer closed the connection;
    /// - `Timeout` if nothing was received (including spurious wake-ups and EINTR).
    ///
    /// # Errors
    ///
    /// `NotConnected` if there is no client, plus any other I/O error from polling or
    /// reading.
    fn poll(&self, timeout: std::time::Duration) -> std::io::Result<PollStatus> {
        // SAFETY: see `accept_pending`.
        let bufcache = unsafe { &mut *self.bufcache.get() };
        if !bufcache.is_empty() && self.accept_input.load(Ordering::Relaxed) {
            self.accept_input.store(false, Ordering::Relaxed);
            return Ok(PollStatus::Ready);
        }

        // SAFETY: see `accept_pending`.
        let events = unsafe { &mut *self.stream_events.get() };
        let poll = unsafe { &mut *self.stream_poll.get() };
        match poll.poll(events, Some(timeout)) {
            Ok(()) => {},
            Err(err) if err.kind() == std::io::ErrorKind::Interrupted => {
                return Ok(PollStatus::Timeout);
            },
            Err(err) => return Err(err),
        }
        if !events.iter().any(|event| event.token() == Token(1) && event.is_readable()) {
            return Ok(PollStatus::Timeout);
        }

        let mut lock = self.stream.lock().unwrap();
        let stream = lock
            .as_mut()
            .ok_or_else(|| std::io::Error::from(std::io::ErrorKind::NotConnected))?;

        // mio readiness is edge-triggered, so the socket must be read until it reports
        // `WouldBlock`. Otherwise, data left in the kernel buffer (more than one `buf`
        // worth) would not be reported again until the client sends more.
        let mut buf = [0u8; 4096];
        let mut received = false;
        let mut telnet_events = Vec::new();
        loop {
            let size = match stream.read(&mut buf) {
                // The peer closed the connection.
                Ok(0) => return Ok(PollStatus::ConnectAborted),
                Ok(size) => size,
                Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => break,
                Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(err) => return Err(err),
            };
            received = true;
            let (size, mut new_events) = self.tcp_chr_process_IAC_bytes(&mut buf[..size]);
            bufcache.extend(&buf[..size]);
            telnet_events.append(&mut new_events);
        }

        if !telnet_events.is_empty() {
            Ok(PollStatus::Event(telnet_events))
        } else if received {
            Ok(PollStatus::Ready)
        } else {
            // Spurious wake-up: nothing was actually readable.
            Ok(PollStatus::Timeout)
        }
    }

    /// Number of decoded bytes buffered and ready for `read`.
    fn can_read(&self) -> usize {
        // SAFETY: see `accept_pending`.
        unsafe { (*self.bufcache.get()).len() }
    }

    /// Moves up to `buf.len()` buffered bytes into `buf` and returns how many were
    /// copied.
    fn read(&self, buf: &mut [u8]) -> std::io::Result<usize> {
        // SAFETY: see `accept_pending`.
        let bufcache = unsafe { &mut *self.bufcache.get() };
        debug_assert!(!bufcache.is_empty());

        // Copy exactly as many bytes as are both requested and available.
        let count = buf.len().min(bufcache.len());
        for (dst, byte) in buf.iter_mut().zip(bufcache.drain(..count)) {
            *dst = byte;
        }
        Ok(count)
    }

    /// Writes `buf` to the connected client, returning how many bytes were accepted.
    ///
    /// # Errors
    ///
    /// `NotConnected` if no client is connected; otherwise any error from the socket
    /// (including `WouldBlock`, since the stream is non-blocking).
    fn write(&self, buf: &[u8]) -> std::io::Result<usize> {
        let mut lock = self.stream.lock().unwrap();
        match lock.as_mut() {
            Some(stream) => stream.write(buf),
            None => Err(std::io::ErrorKind::NotConnected.into()),
        }
    }

    /// Marks that the consumer can accept more input, so the next `poll` reports any
    /// still-buffered bytes as `Ready`.
    fn notify_accept_input(&self) {
        self.accept_input.store(true, Ordering::Relaxed);
    }

    /// Shuts down both directions of the current client connection, if there is one.
    fn end(&self) -> std::io::Result<()> {
        let lock = self.stream.lock().unwrap();
        if let Some(stream) = lock.as_ref() {
            stream.shutdown(std::net::Shutdown::Both)?;
        }
        Ok(())
    }
}

/// Callback run for every newly created telnet session.
type TelnetNotifyFn = fn(SocketAddr, Arc<CharBackendMap>, Arc<Monitor>, Arc<SiminkManage>);

/// Registry of telnet character backends.
///
/// Holds one `CharBackendMap` per telnet session, keyed by socket address, together with
/// the optional session-creation callback and the context passed to it.
pub struct TelnetMap {
    // Tuple fields:
    //   .0 = session-creation callback
    //   .1 = backend maps per session
    //   .2 = manager passed to the callback
    //   .3 = monitor passed to the callback
    telnet_map: Mutex<(
        Option<TelnetNotifyFn>,
        HashMap<SocketAddr, Arc<CharBackendMap>>,
        Option<Arc<SiminkManage>>,
        Option<Arc<Monitor>>,
    )>,
}

impl TelnetMap {
    /// Registers the callback run for every new session, together with the monitor and
    /// manager that are passed to it. Replaces any previous registration.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn register_notify(
        &self,
        f: fn(SocketAddr, Arc<CharBackendMap>, Arc<Monitor>, Arc<SiminkManage>),
        mon: Arc<Monitor>,
        manage: Arc<SiminkManage>,
    ) {
        let mut lock = self.telnet_map.lock().unwrap();
        lock.0 = Some(f);
        lock.2 = Some(manage);
        lock.3 = Some(mon);
    }

    /// Creates a new remote session for `sockaddr`.
    ///
    /// A fresh `CharBackendMap` is stored for `sockaddr`, replacing any existing one. If
    /// a callback has been registered, it is then called with:
    /// - the address;
    /// - the new map;
    /// - the registered monitor and manager.
    ///
    /// The callback runs after the internal lock is released. It may therefore call back
    /// into this registry without deadlocking, and it does not block other lookups while
    /// it runs.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn create_session(&self, sockaddr: SocketAddr) {
        let map = Arc::new(CharBackendMap::default());
        let notify = {
            let mut lock = self.telnet_map.lock().unwrap();
            lock.1.insert(sockaddr, Arc::clone(&map));
            // `register_notify` always sets the callback, monitor and manager together.
            match (lock.0, &lock.3, &lock.2) {
                (Some(f), Some(mon), Some(manage)) => {
                    Some((f, Arc::clone(mon), Arc::clone(manage)))
                },
                _ => None,
            }
        };
        if let Some((f, mon, manage)) = notify {
            f(sockaddr, map, mon, manage);
        }
    }

    /// Returns the character backend map of the session for `sockaddr`, if one exists.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn char_backend_map(&self, sockaddr: SocketAddr) -> Option<Arc<CharBackendMap>> {
        let lock = self.telnet_map.lock().unwrap();
        lock.1.get(&sockaddr).cloned()
    }
}

/// Process-wide telnet session registry. It is created on first use, with no callback
/// and no sessions.
static TELNET_MAP: LazyLock<Arc<TelnetMap>> = LazyLock::new(|| {
    let empty = (None, HashMap::new(), None, None);
    Arc::new(TelnetMap { telnet_map: Mutex::new(empty) })
});

/// Returns a shared handle to the global telnet character backend registry.
pub fn chardevice_telnet_map() -> Arc<TelnetMap> {
    Arc::clone(&*TELNET_MAP)
}
